{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 171,
   "id": "403d179c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from neo4j import GraphDatabase, basic_auth\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import pandas as pd\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80ee9a49",
   "metadata": {},
   "source": [
    "## Set number of questions to generate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "id": "fa80e37b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of rows (one question each) sampled into the saved test set.\n",
    "N_QUESTIONS = 100\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac046718",
   "metadata": {},
   "source": [
    "## Load KG credentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "id": "8d41be45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the Neo4j connection settings from a dotenv file in the home directory.\n",
    "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
    "\n",
    "# Each value is None when the corresponding variable is not set.\n",
    "username, password, url, database = (\n",
    "    os.getenv(key) for key in ('NEO4J_USER', 'NEO4J_PSW', 'NEO4J_URI', 'NEO4J_DB')\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf3354e7",
   "metadata": {},
   "source": [
    "## Load disease names stored in vectorDB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "id": "2ec9d667",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.\n",
    "with open('../data/disease_with_relation_to_genes.pickle', 'rb') as pickle_file:\n",
    "    disease_names = pickle.load(pickle_file)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "654a9a58",
   "metadata": {},
   "source": [
    "## Extract GWAS Disease-Gene relations from the KG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "c280e781",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 158 ms, sys: 19.6 ms, total: 178 ms\n",
      "Wall time: 550 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Open a driver to the Neo4j knowledge graph using the credentials loaded above.\n",
    "# NOTE(review): `database` is read from the env file but never passed to the session,\n",
    "# so presumably the server's default database is queried -- confirm this is intended.\n",
    "auth = basic_auth(username, password)\n",
    "sdb = GraphDatabase.driver(url, auth=auth)\n",
    "\n",
    "# Disease -> Gene associations whose `sources` list is exactly ['GWAS'], ordered by p-value.\n",
    "# NOTE(review): the second WITH groups by (d, pvalue), so this yields one gene per\n",
    "# disease/p-value pair rather than a single lowest-p-value gene per disease --\n",
    "# confirm that this is the intended selection.\n",
    "gwas_query = '''\n",
    "    MATCH (d:Disease)-[r:ASSOCIATES_DaG]->(g:Gene)\n",
    "    WHERE r.sources = ['GWAS']\n",
    "    WITH d, g, r.gwas_pvalue AS pvalue\n",
    "    ORDER BY pvalue\n",
    "    WITH d, COLLECT(g)[0] AS gene_with_lowest_pvalue, pvalue\n",
    "    RETURN d.name AS disease_name, gene_with_lowest_pvalue.name AS gene_name, pvalue\n",
    "'''\n",
    "\n",
    "try:\n",
    "    with sdb.session() as session:\n",
    "        with session.begin_transaction() as tx:\n",
    "            # Materialise the rows while the transaction is still open.\n",
    "            out_list = [\n",
    "                (row['disease_name'], row['gene_name'], row['pvalue'])\n",
    "                for row in tx.run(gwas_query)\n",
    "            ]\n",
    "finally:\n",
    "    # Close the driver even when the query fails, so its connections are not leaked.\n",
    "    sdb.close()\n",
    "\n",
    "gwas_disease_names = pd.DataFrame(out_list, columns=['disease_name', 'gene_name', 'gwas_pvalue']).drop_duplicates()\n",
    "\n",
    "# Keep only diseases present in `disease_names`. The .copy() makes the filtered result an\n",
    "# independent frame, so columns can be added to it later without SettingWithCopyWarning.\n",
    "gwas_disease_names = gwas_disease_names[gwas_disease_names.disease_name.isin(disease_names)].copy()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0db2757f",
   "metadata": {},
   "source": [
    "## Create test questions from the extracted relationships"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "9fe85753",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 97.3 ms, sys: 1.08 ms, total: 98.4 ms\n",
      "Wall time: 97.7 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Two question phrasings; one is picked at random for every disease-gene pair.\n",
    "template_questions = [\n",
    "    'Is {} associated with {}?',\n",
    "    'What is the GWAS p-value for the association between {} and {}?'\n",
    "]\n",
    "\n",
    "random.seed(42)  # fixed seed so the generated questions are reproducible\n",
    "test_questions = []\n",
    "for disease, gene in zip(gwas_disease_names['disease_name'], gwas_disease_names['gene_name']):\n",
    "    template = random.choice(template_questions)\n",
    "    # A coin flip decides whether the disease or the gene is mentioned first.\n",
    "    first, second = (disease, gene) if random.random() < 0.5 else (gene, disease)\n",
    "    test_questions.append(template.format(first, second))\n",
    "\n",
    "gwas_disease_names.loc[:, 'question'] = test_questions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f1800f5",
   "metadata": {},
   "source": [
    "## Create perturbed test questions (lower-case names) from the extracted relationships"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "c788c8d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 96 ms, sys: 962 µs, total: 97 ms\n",
      "Wall time: 96.3 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Question phrasings; one is picked at random for every disease-gene pair.\n",
    "template_questions = [\n",
    "    'Is {} associated with {}?',\n",
    "    'What is the GWAS p-value for the association between {} and {}?'\n",
    "]\n",
    "\n",
    "# Re-seeding with 42 and making the same sequence of random calls over the same rows\n",
    "# replays the template/order choices of a loop seeded the same way. Assuming the\n",
    "# `question` column was built that way, each perturbed question differs from its\n",
    "# counterpart only in the lower-cased names.\n",
    "random.seed(42)\n",
    "test_questions_perturbed = []\n",
    "for disease, gene in zip(gwas_disease_names['disease_name'], gwas_disease_names['gene_name']):\n",
    "    template = random.choice(template_questions)\n",
    "    disease, gene = disease.lower(), gene.lower()\n",
    "    # A coin flip decides whether the disease or the gene is mentioned first.\n",
    "    first, second = (disease, gene) if random.random() < 0.5 else (gene, disease)\n",
    "    test_questions_perturbed.append(template.format(first, second))\n",
    "\n",
    "gwas_disease_names.loc[:, 'question_perturbed'] = test_questions_perturbed\n",
    "\n",
    "# The N_QUESTIONS subset is drawn in the save step, so it is not sampled here as well.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06eed996",
   "metadata": {},
   "source": [
    "## Save the test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "7f02bb5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a reproducible random subset of N_QUESTIONS rows and write it out as CSV\n",
    "# (the DataFrame index is not written).\n",
    "gwas_disease_names_selected = gwas_disease_names.sample(n=N_QUESTIONS, random_state=42)\n",
    "gwas_disease_names_selected.to_csv('../data/rag_comparison_data.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea680eb0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
